"""
Tests fallback neighbor search on a small 2d grid
that was calculated manually
"""

import numpy as np
import paddle
import pytest

from ..simple_neighbor_search import simple_neighbor_search

# Manually-calculated CSR list of neighbors
# in a 5x5 grid on [0,1] X [0,1] for radius=0.3

indices = [
    0,
    1,
    5,
    0,
    1,
    2,
    6,
    1,
    2,
    3,
    7,
    2,
    3,
    4,
    8,
    3,
    4,
    9,
    0,
    5,
    6,
    10,
    1,
    5,
    6,
    7,
    11,
    2,
    6,
    7,
    8,
    12,
    3,
    7,
    8,
    9,
    13,
    4,
    8,
    9,
    14,
    5,
    10,
    11,
    15,
    6,
    10,
    11,
    12,
    16,
    7,
    11,
    12,
    13,
    17,
    8,
    12,
    13,
    14,
    18,
    9,
    13,
    14,
    19,
    10,
    15,
    16,
    20,
    11,
    15,
    16,
    17,
    21,
    12,
    16,
    17,
    18,
    22,
    13,
    17,
    18,
    19,
    23,
    14,
    18,
    19,
    24,
    15,
    20,
    21,
    16,
    20,
    21,
    22,
    17,
    21,
    22,
    23,
    18,
    22,
    23,
    24,
    19,
    23,
    24,
]

splits = [
    0,
    3,
    7,
    11,
    15,
    18,
    22,
    27,
    32,
    37,
    41,
    45,
    50,
    55,
    60,
    64,
    68,
    73,
    78,
    83,
    87,
    90,
    94,
    98,
    102,
    105,
]


def test_fallback_nb_search():
    mesh_grid = np.stack(
        np.meshgrid(*[np.linspace(0, 1, 5) for _ in range(2)], indexing="ij"), axis=-1
    )
    coords = paddle.to_tensor(
        mesh_grid.reshape(-1, 2)
    )  # reshape into n**d x d coord points
    return_dict = simple_neighbor_search(data=coords, queries=coords, radius=0.3)

    assert return_dict["neighbors_index"].tolist() == indices
    assert return_dict["neighbors_row_splits"].tolist() == splits
